import math
import textwrap
from io import BytesIO

import gdspy  # Mock-up for GDSII handling (using gdspy library)
import matplotlib.pyplot as plt
import numba
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import streamlit as st
import svgwrite
from scipy import optimize
from scipy.ndimage import gaussian_filter

# Placeholder functions for GDSII and SVG handling
def gdspy_to_array(gds_lib, grid_size):
    """Placeholder for converting GDSII data to a numpy array mask pattern.

    Not implemented: it reports an error in the Streamlit app and then calls
    ``st.stop()``. Neither ``gds_lib`` nor ``grid_size`` is used, and no value
    is returned.
    """
    st.error("GDSII file handling is not implemented.")
    st.stop()

def svg_to_array(svg_file, grid_size):
    """Placeholder for converting SVG data to a numpy array mask pattern.

    Not implemented: it reports an error in the Streamlit app and then calls
    ``st.stop()``. Neither ``svg_file`` nor ``grid_size`` is used, and no value
    is returned.
    """
    st.error("SVG file handling is not implemented.")
    st.stop()

# Implement actual Zernike polynomials
def zernike(n, m, rho, phi):
    """Evaluate the (unnormalised) Zernike polynomial Z_n^m in polar coordinates.

    Args:
        n: Radial order (non-negative integer).
        m: Azimuthal order; ``m >= 0`` selects the cosine term and ``m < 0``
            the sine term.
        rho: Radial coordinate(s), scalar or array (any numeric dtype).
        phi: Azimuthal angle(s) in radians, broadcastable against ``rho``.

    Returns:
        ``R_n^|m|(rho) * cos(m * phi)`` for ``m >= 0`` or
        ``R_n^|m|(rho) * sin(|m| * phi)`` for ``m < 0``, as a float array.
        All zeros when ``n - |m|`` is odd (no such polynomial exists).
    """
    # Work in floating point so integer inputs do not break the in-place sum.
    rho = np.asarray(rho, dtype=float)
    abs_m = abs(m)
    if (n - abs_m) % 2 != 0:
        return np.zeros_like(rho)

    half_sum = (n + abs_m) // 2
    half_diff = (n - abs_m) // 2
    radial = np.zeros_like(rho)
    # math.factorial replaces np.math.factorial: the np.math alias no longer
    # exists in current NumPy releases.
    for k in range(half_diff + 1):
        numerator = (-1) ** k * math.factorial(n - k)
        denominator = (math.factorial(k)
                       * math.factorial(half_sum - k)
                       * math.factorial(half_diff - k))
        radial += numerator / denominator * rho ** (n - 2 * k)

    if m >= 0:
        return radial * np.cos(m * phi)
    return radial * np.sin(-m * phi)

def calculate_aerial_image_2d(X, Y, na, sigma, wavelength_um, pitch_um, duty_cycle,
                              defocus_um, mask_pattern, polarization="unpolarized", aberrations=None):
    """Compute a normalised 2-D aerial image of a mask via Fourier filtering.

    The mask spectrum is multiplied by a transfer function made of the pupil
    aperture, an aberration phase, a polarization weighting and a defocus
    phase, then transformed back and squared to obtain intensity.

    Args:
        X, Y: Coordinate grids in micrometres as produced by
            ``np.meshgrid(x, y)`` with the default ``'xy'`` indexing
            (``X`` varies along axis 1, ``Y`` along axis 0), uniformly spaced.
        na: Numerical aperture (pupil cutoff is ``na / wavelength_um``).
        sigma: Partial-coherence factor. For ``sigma > 0`` the pupil is zeroed
            inside ``na * (1 - sigma)`` and ramps linearly up to 1 at ``na``.
        wavelength_um: Exposure wavelength in micrometres.
        pitch_um, duty_cycle: Accepted for interface compatibility; not used.
        defocus_um: Defocus distance in micrometres.
        mask_pattern: 2-D mask transmission array with the same shape as ``X``.
        polarization: ``"s"``, ``"p"`` or anything else for unpolarized.
        aberrations: Optional iterable of ``(n, m, coeff)`` Zernike terms;
            ``coeff`` is interpreted as wavefront error in waves.

    Returns:
        Float array with the shape of ``X``: intensity shifted to a minimum of
        0 and scaled to a maximum of 1 (all zeros if the image is uniform).

    Raises:
        ValueError: If the grids are not matching 2-D ``'xy'`` meshgrids or
            ``mask_pattern`` does not match the grid shape.
    """
    X = np.asarray(X, dtype=float)
    Y = np.asarray(Y, dtype=float)
    mask_pattern = np.asarray(mask_pattern)
    if X.ndim != 2 or X.shape != Y.shape or min(X.shape) < 2:
        raise ValueError("X and Y must be matching 2-D grids with at least 2 samples per axis")
    if mask_pattern.shape != X.shape:
        raise ValueError(
            f"mask_pattern shape {mask_pattern.shape} does not match grid shape {X.shape}")

    # With 'xy' meshgrid indexing X changes along axis 1 and Y along axis 0.
    # Reading the spacing along the other axis yields 0, which makes every
    # spatial frequency inf/NaN and the resulting image identically zero.
    dx = X[0, 1] - X[0, 0]
    dy = Y[1, 0] - Y[0, 0]
    if dx == 0 or dy == 0:
        raise ValueError("X and Y must come from np.meshgrid(x, y) with 'xy' indexing")

    n_rows, n_cols = X.shape
    fx = np.fft.fftfreq(n_cols, d=dx)
    fy = np.fft.fftfreq(n_rows, d=dy)
    FX, FY = np.meshgrid(fx, fy)
    F_rho = np.hypot(FX, FY)
    F_phi = np.arctan2(FY, FX)
    # Radial pupil coordinate in NA units (equals `na` at the aperture edge).
    pupil_radius = F_rho * wavelength_um

    # Keep the spectrum in native FFT ordering so that it lines up with the
    # fftfreq-ordered FX/FY grids used to build the transfer function.
    mask_ft = np.fft.fft2(mask_pattern)

    # Aberrations: Zernike wavefront error (in waves) applied as a pupil phase,
    # evaluated on the pupil radius normalised to 1 at the aperture edge.
    aberration = np.ones_like(F_rho, dtype=complex)
    if aberrations:
        wavefront_waves = np.zeros_like(F_rho)
        normalised_radius = pupil_radius / na
        for n, m, coeff in aberrations:
            wavefront_waves += coeff * zernike(n, m, normalised_radius, F_phi)
        aberration = np.exp(2j * np.pi * wavefront_waves)

    # Polarization weighting by azimuth in the pupil.
    if polarization == "s":
        polarization_factor = np.cos(F_phi) ** 2
    elif polarization == "p":
        polarization_factor = np.sin(F_phi) ** 2
    else:
        polarization_factor = np.ones_like(F_phi)

    # Pupil aperture.
    # NOTE(review): for sigma > 0 this blocks the pupil centre (including DC)
    # rather than modelling partial coherence by source integration - confirm
    # this annular model is intended.
    na_outer = na
    na_inner = na * (1 - sigma)
    pupil = np.zeros_like(F_rho)
    within_outer_na = pupil_radius <= na_outer
    pupil[within_outer_na] = 1.0
    if sigma > 0:
        within_inner_na = pupil_radius <= na_inner
        pupil[within_inner_na] = 0.0
        transition_zone = within_outer_na & ~within_inner_na
        pupil[transition_zone] *= np.clip(
            (pupil_radius[transition_zone] - na_inner) / (na_outer - na_inner), 0, 1)

    # Defocus phase; the argument is clipped at 0 outside the pupil.
    # NOTE(review): uses (na / wavelength)^2 rather than (n / wavelength)^2
    # under the square root - confirm against the intended defocus model.
    sqrt_arg = np.maximum((na / wavelength_um) ** 2 - (FX ** 2 + FY ** 2), 0.0)
    phase = -2j * np.pi * defocus_um * np.sqrt(sqrt_arg)
    otf = pupil * aberration * polarization_factor * np.exp(phase)
    otf[np.isnan(otf)] = 0  # Guard against NaNs coming from non-finite inputs

    image = np.abs(np.fft.ifft2(mask_ft * otf)) ** 2

    # Normalise to [0, 1]; a uniform image collapses to all zeros.
    image -= image.min()
    peak = image.max()
    if peak != 0:
        image /= peak
    else:
        image = np.zeros_like(image)

    return image

def add_shot_noise(image, dose):
    """Apply Poisson shot noise to ``image`` at the given ``dose``.

    The image is scaled by the dose to get the Poisson mean per pixel, a
    Poisson sample is drawn from the global NumPy RNG, and the result is
    divided by the dose again.

    Args:
        image: Array of non-negative intensities.
        dose: Scale factor; values below ``1e-9`` are raised to ``1e-9``.

    Returns:
        Array of noisy intensities with the shape of ``image``.
    """
    # A non-positive dose would give invalid Poisson means and a zero division.
    dose = max(dose, 1e-9)
    expected_counts = image * dose
    # Cap the Poisson mean to prevent overflow in the sampler.
    too_large = expected_counts > 1e8
    expected_counts[too_large] = 1e8
    return np.random.poisson(expected_counts) / dose

def simulate_acid_diffusion(image, diffusion_length_pixels):
    """Model acid diffusion as a Gaussian blur.

    Args:
        image: Input array.
        diffusion_length_pixels: Standard deviation of the Gaussian, in pixels.

    Returns:
        The blurred array.
    """
    blurred = gaussian_filter(image, sigma=diffusion_length_pixels)
    return blurred

# Implement Dill's ABC model for resist bleaching
def simulate_resist_bleaching(exposed_intensity, exposure_time, A, B, C):
    """Return a simplified exposure-dependent transmittance.

    Computes ``A * exp(-B * exposure_time * exposed_intensity) + C``
    element-wise.

    Args:
        exposed_intensity: Intensity array (or scalar).
        exposure_time: Exposure time multiplier.
        A, B, C: Model coefficients (amplitude, decay rate, offset).

    Returns:
        Transmittance with the shape of ``exposed_intensity``.
    """
    decay_exponent = -B * exposure_time * exposed_intensity
    return A * np.exp(decay_exponent) + C

# Implement Standing Wave Effects
def apply_standing_wave_effect(intensity, wavelength_um, resist_thickness_um):
    """Scale ``intensity`` by a thickness-dependent standing-wave factor.

    The factor is ``0.5 * (1 + cos(2*pi * resist_thickness_um / wavelength_um))``,
    a single scalar applied uniformly to the whole array.

    Args:
        intensity: Intensity array (or scalar).
        wavelength_um: Wavelength, in the same length unit as the thickness.
        resist_thickness_um: Resist thickness.

    Returns:
        ``intensity`` multiplied by the modulation factor.
    """
    phase = 2 * np.pi * resist_thickness_um / wavelength_um
    modulation = 0.5 * (1 + np.cos(phase))
    return intensity * modulation

# Simulate PAG Diffusion
def simulate_pag_diffusion(image, pag_diffusion_length_pixels):
    """Model PAG diffusion as a Gaussian blur.

    Args:
        image: Input array.
        pag_diffusion_length_pixels: Standard deviation of the Gaussian, in pixels.

    Returns:
        The blurred array.
    """
    diffused = gaussian_filter(image, sigma=pag_diffusion_length_pixels)
    return diffused

# 3D Resist Profile Simulation
def simulate_resist_profile_3d(developed_image, resist_thickness_nm):
    """Convert a developed image into a height map.

    Args:
        developed_image: Array of per-pixel values (booleans act as 0/1).
        resist_thickness_nm: Height assigned to a pixel value of 1.

    Returns:
        ``developed_image`` multiplied element-wise by ``resist_thickness_nm``.
    """
    return developed_image * resist_thickness_nm

def measure_critical_dimension(resist_image, pixel_size_nm):
    """Measure the width of the first complete feature in a resist image.

    The image is averaged over axis 0 into a 1-D profile, thresholded at 0.5,
    and the distance between the first rising edge and the first falling edge
    *after it* is reported.

    Args:
        resist_image: 2-D array (boolean or numeric) of the developed resist.
        pixel_size_nm: Size of one pixel along axis 1, in nanometres.

    Returns:
        Feature width in nanometres, or ``np.nan`` when no complete
        rising-then-falling edge pair exists in the profile.
    """
    profile = np.mean(resist_image, axis=0)
    threshold_level = 0.5  # Adjust based on resist_image contrast
    edges = np.diff((profile > threshold_level).astype(int))

    rising_edges = np.flatnonzero(edges == 1)
    if rising_edges.size == 0:
        return np.nan
    first_rise = rising_edges[0]

    # Only falling edges after the rising edge close the feature; a profile
    # that starts high has an earlier falling edge which would otherwise
    # produce a negative width.
    falling_edges = np.flatnonzero(edges == -1)
    closing_edges = falling_edges[falling_edges > first_rise]
    if closing_edges.size == 0:
        return np.nan

    return (closing_edges[0] - first_rise) * pixel_size_nm

def optimize_parameters(target_cd_nm, initial_params, bounds, simulation_params):
    """Search for (NA, sigma, dose) that bring the simulated CD to a target.

    Args:
        target_cd_nm: Desired critical dimension in nanometres.
        initial_params: Starting point ``[na, sigma, dose]``.
        bounds: Three ``(low, high)`` pairs for NA, sigma and dose.
        simulation_params: Tuple of ``(X, Y, wavelength_um, pitch_um,
            duty_cycle, defocus_um, mask_pattern, diffusion_length_nm,
            threshold, pixel_size_nm, polarization, aberrations)``.

    Returns:
        Array ``[na, sigma, dose]`` found by SciPy's Powell minimiser.
    """
    # Everything below is invariant across objective evaluations, so it is
    # prepared once instead of on every call.
    (X, Y, wavelength_um, pitch_um, duty_cycle, defocus_um, mask_pattern,
     diffusion_length_nm, threshold, pixel_size_nm, polarization, aberrations) = simulation_params

    if mask_pattern is None or mask_pattern.size == 0:
        # Generate the fallback mask once: a new random mask per evaluation
        # would change the objective function under the optimiser (and emit
        # one warning per evaluation).
        st.warning("Mask pattern is not provided or empty. Generating a random mask pattern.")
        mask_pattern = np.random.rand(*X.shape)

    diffusion_length_pixels = diffusion_length_nm / pixel_size_nm
    (na_low, na_high), (sigma_low, sigma_high), (dose_low, dose_high) = bounds
    noise_seed = 0  # fixed seed so every evaluation sees the same noise draw

    def objective(params):
        """Return |CD - target| in nm, or inf when out of bounds / unmeasurable."""
        na, sigma, dose = params

        in_bounds = (na_low <= na <= na_high and
                     sigma_low <= sigma <= sigma_high and
                     dose_low <= dose <= dose_high)
        if not in_bounds:
            return np.inf

        intensity = calculate_aerial_image_2d(X, Y, na, sigma, wavelength_um, pitch_um,
                                              duty_cycle, defocus_um, mask_pattern, polarization, aberrations)

        # Fresh shot noise on each call makes the objective stochastic, which
        # misleads Powell's line searches. Re-seed for the noise draw and
        # restore the global RNG state afterwards so callers are unaffected.
        saved_rng_state = np.random.get_state()
        np.random.seed(noise_seed)
        try:
            exposed_intensity = add_shot_noise(intensity, dose)
        finally:
            np.random.set_state(saved_rng_state)

        developed_image = simulate_acid_diffusion(exposed_intensity, diffusion_length_pixels)
        resist_image = developed_image >= threshold
        cd_nm = measure_critical_dimension(resist_image, pixel_size_nm)
        if np.isnan(cd_nm):
            return np.inf
        return np.abs(cd_nm - target_cd_nm)

    result = optimize.minimize(
        objective, initial_params, method='Powell', bounds=bounds
    )
    return result.x

# Streamlit UI
# Widgets are created in source order, which also fixes their order in the
# sidebar; every widget's current value is bound to a module-level name that
# the simulation code further down reads.
st.title("Advanced Lithography Simulation with Enhanced Physical Models")

# Sidebar inputs for simulation parameters
st.sidebar.header("Simulation Parameters")

# Mask source: built-in generators, user-supplied code, or an uploaded file
mask_type = st.sidebar.selectbox(
    "Mask Type",
    options=["Line-Space Grating", "Contact Hole Array", "Custom Pattern", "Upload Pattern", "GDSII", "OASIS", "SVG"],
    help="Select the type of mask pattern to simulate."
)

# Optical settings (lengths entered in nm)
wavelength = st.sidebar.number_input(
    "Wavelength (nm)", min_value=10.0, value=193.0, step=1.0,
    help="Exposure wavelength in nanometers."
)
na = st.sidebar.slider(
    "Numerical Aperture (NA)", min_value=0.1, max_value=1.4, value=0.85, step=0.01,
    help="Numerical aperture of the projection system."
)
sigma = st.sidebar.slider(
    "Partial Coherence Factor (σ)", min_value=0.0, max_value=1.0, value=0.7, step=0.01,
    help="Partial coherence factor of the illumination."
)
# Pattern geometry used by the built-in mask generators and the grid extent
pitch = st.sidebar.number_input(
    "Pitch (nm)", min_value=50.0, value=180.0, step=1.0,
    help="Pitch of the grating pattern."
)
duty_cycle = st.sidebar.slider(
    "Duty Cycle", min_value=0.1, max_value=0.9, value=0.5, step=0.01,
    help="Ratio of the line width to the pitch."
)
# Exposure conditions
defocus_nm = st.sidebar.slider(
    "Defocus (nm)", min_value=-500.0, max_value=500.0, value=0.0, step=10.0,
    help="Defocus value in nanometers."
)
dose = st.sidebar.slider(
    "Exposure Dose (Relative)", min_value=0.5, max_value=1.5, value=1.0, step=0.01,
    help="Relative exposure dose."
)
# Intensity level at or above which a pixel counts as resist in the final image
threshold = st.sidebar.slider(
    "Resist Threshold", min_value=0.0, max_value=1.0, value=0.3, step=0.01,
    help="Threshold for resist development."
)

# Advanced resist parameters
st.sidebar.header("Advanced Resist Parameters")
diffusion_length_nm = st.sidebar.number_input(
    "Diffusion Length (nm)", min_value=0.0, value=10.0, step=1.0,
    help="Diffusion length during post-exposure bake."
)
resist_thickness_nm = st.sidebar.number_input(
    "Resist Thickness (nm)", min_value=10.0, value=100.0, step=1.0,
    help="Thickness of the photoresist."
)

# Implement Dill's ABC parameters
# With the defaults (A=1, B=0, C=0) simulate_resist_bleaching returns 1
# everywhere, i.e. bleaching has no effect until these are changed.
st.sidebar.subheader("Dill's Parameters for Resist Bleaching")
A = st.sidebar.number_input("A Parameter", value=1.0, help="Dill's A parameter.")
B = st.sidebar.number_input("B Parameter", value=0.0, help="Dill's B parameter.")
C = st.sidebar.number_input("C Parameter", value=0.0, help="Dill's C parameter.")
exposure_time = st.sidebar.number_input("Exposure Time (s)", value=1.0, help="Exposure time in seconds.")

# Add aberration inputs
# One checkbox per (n, m) with n in 0..4 and m = -n, -n+2, ..., n; ticking a
# box reveals a coefficient slider and adds the term to the list.
st.sidebar.header("Aberrations")
aberrations = []  # List to hold (n, m, coefficient) for aberrations
for n in range(5):
    for m in range(-n, n+1, 2):
        if st.sidebar.checkbox(f"Zernike ({n}, {m})"):
            coeff = st.sidebar.slider(f"Coefficient for Z({n}, {m})", -1.0, 1.0, 0.0, 0.01)
            aberrations.append((n, m, coeff))

# Add polarization selection
polarization = st.sidebar.selectbox("Polarization", ["unpolarized", "s", "p"])

# Spatial positions for 2D simulation
grid_size = 500  # Adjust for resolution
x_range_um = 2 * (pitch / 1000.0)  # Half-width of the simulated window (2 pitches each side)
# endpoint=False makes the sampled window span exactly four pitches, so a
# periodic mask is also periodic on the FFT grid. Sampling both -range and
# +range duplicates the boundary sample and makes the period slightly longer
# than four pitches, which leaks energy between spatial frequencies.
x = np.linspace(-x_range_um, x_range_um, grid_size, endpoint=False)
y = np.linspace(-x_range_um, x_range_um, grid_size, endpoint=False)
X, Y = np.meshgrid(x, y)

# Pixel size in nm (used for CD measurement)
pixel_size_nm = (x[1] - x[0]) * 1000.0

# Build the mask pattern for the selected mask type
if mask_type == "Custom Pattern":
    st.sidebar.header("Custom Pattern Input")
    custom_pattern_code = st.sidebar.text_area(
        "Enter custom pattern code (Python code defining 'mask_pattern')",
        value="mask_pattern = np.where((X**2 + Y**2) < (0.5 * (pitch / 1000.0))**2, 1, 0)"
    )
    # SECURITY: exec() runs arbitrary user-supplied Python with the full
    # privileges of this process. Acceptable only for a trusted, local tool;
    # do not expose this app to untrusted users without replacing this input.
    mask_pattern = None
    try:
        exec(custom_pattern_code, globals(), locals())
    except Exception as e:
        st.error(f"Error in custom pattern code: {e}")
        st.stop()
    if mask_pattern is None:
        # The code ran but never assigned the expected variable.
        st.error("Error in custom pattern code: the code must define 'mask_pattern'.")
        st.stop()
elif mask_type == "Line-Space Grating":
    # Create a line-space grating mask pattern
    mask_pattern = np.where((np.mod(X, pitch / 1000.0)) < (duty_cycle * pitch / 1000.0), 1, 0)
elif mask_type == "Contact Hole Array":
    # One circular hole centred in every pitch x pitch cell
    X_mod = np.mod(X, pitch / 1000.0) - 0.5 * pitch / 1000.0
    Y_mod = np.mod(Y, pitch / 1000.0) - 0.5 * pitch / 1000.0
    hole_radius = 0.5 * duty_cycle * pitch / 1000.0
    mask_pattern = np.where(X_mod**2 + Y_mod**2 < hole_radius**2, 1, 0)
elif mask_type in ["Upload Pattern", "GDSII", "OASIS", "SVG"]:
    uploaded_file = st.sidebar.file_uploader("Upload Mask Pattern (NumPy .npy, GDSII, OASIS, or SVG file)", type=["npy", "gds", "oas", "svg"])
    if uploaded_file is not None:
        # Compare extensions case-insensitively so "MASK.NPY" is accepted too.
        uploaded_name = uploaded_file.name.lower()
        if uploaded_name.endswith(".npy"):
            try:
                mask_pattern = np.load(uploaded_file)
                if mask_pattern.size == 0:
                    st.warning("Uploaded mask pattern is empty. Generating a random mask pattern.")
                    mask_pattern = np.random.rand(grid_size, grid_size)
            except Exception as e:
                st.error(f"Error loading mask pattern: {e}")
                st.stop()
        elif uploaded_name.endswith(".gds"):
            gds_lib = gdspy.GdsLibrary()
            gds_lib.read_gds(uploaded_file)
            mask_pattern = gdspy_to_array(gds_lib, grid_size)
        elif uploaded_name.endswith(".svg"):
            mask_pattern = svg_to_array(uploaded_file, grid_size)
        else:
            st.error("Unsupported file type.")
            st.stop()
    else:
        # Generate a random mask pattern if none is provided
        mask_pattern = np.random.rand(grid_size, grid_size)
        st.warning("No mask pattern provided. A random pattern has been generated.")
else:
    # Generate a random mask pattern if none is provided
    mask_pattern = np.random.rand(grid_size, grid_size)
    st.warning("No mask pattern provided. A random pattern has been generated.")

# Custom code and uploaded files can yield anything; the simulation multiplies
# the mask spectrum with arrays built on the X/Y grid, so reject mismatches
# here with a readable message instead of failing later inside the FFT code.
mask_pattern = np.asarray(mask_pattern)
if mask_pattern.dtype.kind not in "biufc":
    st.error(f"Mask pattern must be numeric, got dtype {mask_pattern.dtype}.")
    st.stop()
if mask_pattern.shape != X.shape:
    st.error(f"Mask pattern has shape {mask_pattern.shape}; expected {X.shape} to match the simulation grid.")
    st.stop()

# Named optical presets; "Custom" keeps the values chosen in the sidebar
preset_configs = {
    "Custom": {},
    "DUV (ArF)": {"wavelength": 193, "na": 0.93, "sigma": 0.75},
    "KrF": {"wavelength": 248, "na": 0.8, "sigma": 0.6},
    "i-line": {"wavelength": 365, "na": 0.6, "sigma": 0.5},
    "EUV": {"wavelength": 13.5, "na": 0.33, "sigma": 0.2},
}

selected_preset = st.sidebar.selectbox("Preset Configurations", list(preset_configs))

# A non-custom preset overrides wavelength, NA and sigma in one go; keys a
# preset does not define fall back to the current sidebar values.
if selected_preset != "Custom":
    preset = preset_configs[selected_preset]
    wavelength, na, sigma = (
        preset.get("wavelength", wavelength),
        preset.get("na", na),
        preset.get("sigma", sigma),
    )

# Convert all nanometre inputs to micrometres (um) for the optical model
(wavelength_um, pitch_um, defocus_um,
 diffusion_length_um, resist_thickness_um) = (
    value_nm / 1000.0
    for value_nm in (wavelength, pitch, defocus_nm, diffusion_length_nm, resist_thickness_nm)
)

# --- Nominal simulation: aerial image -> exposure -> resist -> 3D view ---
intensity = calculate_aerial_image_2d(
    X, Y, na, sigma, wavelength_um, pitch_um, duty_cycle, defocus_um,
    mask_pattern, polarization, aberrations,
)

# Shot noise at the chosen dose, then scaling by the bleaching transmittance
exposed_intensity = add_shot_noise(intensity, dose)
transmittance = simulate_resist_bleaching(exposed_intensity, exposure_time, A, B, C)
exposed_intensity = exposed_intensity * transmittance

# Acid diffusion (chemical amplification); blur radius expressed in pixels
diffusion_length_pixels = diffusion_length_nm / pixel_size_nm
developed_image = simulate_acid_diffusion(exposed_intensity, diffusion_length_pixels)

# Standing-wave scaling followed by PAG diffusion
standing_wave_intensity = apply_standing_wave_effect(developed_image, wavelength_um, resist_thickness_um)
pag_diffusion_length_pixels = 5  # Adjust as needed
pag_diffused_image = simulate_pag_diffusion(standing_wave_intensity, pag_diffusion_length_pixels)

# Threshold into a binary resist image and lift it to a height map
resist_image = pag_diffused_image >= threshold
resist_profile_3d = simulate_resist_profile_3d(resist_image, resist_thickness_nm)

# Interactive 3D surface of the resist height (axes in nm)
fig_3d = go.Figure()
fig_3d.add_trace(
    go.Surface(z=resist_profile_3d, x=x * 1000, y=y * 1000, colorscale='Viridis')
)
fig_3d.update_layout(
    title='3D Resist Profile',
    scene={
        'xaxis_title': 'X Position (nm)',
        'yaxis_title': 'Y Position (nm)',
        'zaxis_title': 'Resist Height (nm)',
    },
    height=700,
)
st.plotly_chart(fig_3d)

# Critical dimension of the nominal resist image
cd_nm = measure_critical_dimension(resist_image, pixel_size_nm)

# Optimization
st.sidebar.header("Optimization Parameters")
# The target input must live outside the button branch: st.button is True only
# for the single rerun triggered by the click, so a widget created inside the
# branch could never be edited and the target would always be its default.
target_cd_nm = st.sidebar.number_input("Target Critical Dimension (nm)", min_value=1.0, value=50.0, step=1.0)
if st.sidebar.button("Optimize Parameters"):
    initial_params = [na, sigma, dose]
    bounds = [(0.1, 1.4), (0.0, 1.0), (0.5, 1.5)]  # Bounds for na, sigma, dose
    simulation_params = (X, Y, wavelength_um, pitch_um, duty_cycle, defocus_um, mask_pattern,
                         diffusion_length_nm, threshold, pixel_size_nm, polarization, aberrations)
    optimized_params = optimize_parameters(target_cd_nm, initial_params, bounds, simulation_params)
    na_opt, sigma_opt, dose_opt = optimized_params
    st.write(f"Optimized parameters: NA={na_opt:.3f}, σ={sigma_opt:.3f}, Dose={dose_opt:.3f}")

    # Re-run simulation with optimized parameters; the results replace the
    # nominal ones so the plots and CD shown below reflect the optimum.
    intensity = calculate_aerial_image_2d(X, Y, na_opt, sigma_opt, wavelength_um, pitch_um,
                                          duty_cycle, defocus_um, mask_pattern, polarization, aberrations)
    exposed_intensity = add_shot_noise(intensity, dose_opt)
    exposed_intensity *= simulate_resist_bleaching(exposed_intensity, exposure_time, A, B, C)
    developed_image = simulate_acid_diffusion(exposed_intensity, diffusion_length_pixels)
    standing_wave_intensity = apply_standing_wave_effect(developed_image, wavelength_um, resist_thickness_um)
    pag_diffused_image = simulate_pag_diffusion(standing_wave_intensity, pag_diffusion_length_pixels)
    resist_image = pag_diffused_image >= threshold
    cd_nm = measure_critical_dimension(resist_image, pixel_size_nm)
    st.write(f"Measured Critical Dimension after Optimization: {cd_nm:.2f} nm")

# Interactive heatmap of the aerial image intensity (axes in nm)
aerial_heatmap = go.Heatmap(
    z=intensity,
    x=x * 1000,
    y=y * 1000,
    colorscale='Viridis',
    colorbar={'title': 'Intensity'},
)
fig = go.Figure(data=[aerial_heatmap])
fig.update_layout(title_text="Aerial Image Intensity", height=600, width=600)
st.plotly_chart(fig)

# Show the critical dimension measured on the current resist image
st.write(f"Measured Critical Dimension: {cd_nm:.2f} nm")

# Add process window analysis
st.header("Process Window Analysis")
if st.button("Run Process Window Analysis"):
    exposure_range = np.linspace(0.8, 1.2, 5)   # multipliers on the selected dose
    focus_range = np.linspace(-100, 100, 5)     # defocus values in nm

    process_window = np.zeros((len(exposure_range), len(focus_range)))

    # The sweep uses its own variable names so that the nominal results
    # (intensity, exposed_intensity, developed_image, resist_image, ...) keep
    # describing the sidebar settings for the export and report sections.
    for i, exposure_factor in enumerate(exposure_range):
        for j, focus_nm in enumerate(focus_range):
            pw_intensity = calculate_aerial_image_2d(X, Y, na, sigma, wavelength_um, pitch_um,
                                                     duty_cycle, focus_nm / 1000.0, mask_pattern,
                                                     polarization, aberrations)
            pw_exposed = add_shot_noise(pw_intensity, exposure_factor * dose)
            pw_exposed *= simulate_resist_bleaching(pw_exposed, exposure_time, A, B, C)
            pw_developed = simulate_acid_diffusion(pw_exposed, diffusion_length_pixels)
            pw_standing_wave = apply_standing_wave_effect(pw_developed, wavelength_um, resist_thickness_um)
            pw_pag_diffused = simulate_pag_diffusion(pw_standing_wave, pag_diffusion_length_pixels)
            pw_resist = pw_pag_diffused >= threshold
            # NaN marks conditions where no complete feature could be measured.
            process_window[i, j] = measure_critical_dimension(pw_resist, pixel_size_nm)

    fig_pw = px.imshow(process_window, x=focus_range, y=exposure_range,
                       labels=dict(x="Focus (nm)", y="Relative Exposure", color="CD (nm)"),
                       title="Process Window")
    st.plotly_chart(fig_pw)

# Export every intermediate array of the current simulation as one .npz file
if st.button("Download Results"):
    data = dict(
        intensity=intensity,
        exposed_intensity=exposed_intensity,
        developed_image=developed_image,
        standing_wave_intensity=standing_wave_intensity,
        pag_diffused_image=pag_diffused_image,
        resist_image=resist_image.astype(float),  # boolean image stored as 0.0/1.0
        resist_profile_3d=resist_profile_3d,
    )
    # Serialise into memory and rewind so the download reads from the start
    buf = BytesIO()
    np.savez_compressed(buf, **data)
    buf.seek(0)
    st.download_button(
        label="Download Simulation Data",
        data=buf,
        file_name="lithography_simulation_data.npz",
        mime="application/octet-stream",
    )

# Generate a comprehensive report
if st.button("Generate Report"):
    # The template is written at source indentation and dedented afterwards:
    # without that, the downloaded .md file keeps four leading spaces on every
    # line, which Markdown viewers render as one big code block. The parameter
    # summary is a bullet list because consecutive plain lines would otherwise
    # be merged into a single paragraph.
    report = textwrap.dedent(f"""
    # Lithography Simulation Report

    - **Mask Type**: {mask_type}
    - **Wavelength**: {wavelength} nm
    - **Numerical Aperture (NA)**: {na}
    - **Partial Coherence Factor (σ)**: {sigma}
    - **Pitch**: {pitch} nm
    - **Duty Cycle**: {duty_cycle}
    - **Defocus**: {defocus_nm} nm
    - **Exposure Dose**: {dose}
    - **Resist Threshold**: {threshold}
    - **Diffusion Length**: {diffusion_length_nm} nm
    - **Resist Thickness**: {resist_thickness_nm} nm
    - **Dill's Parameters**: A={A}, B={B}, C={C}
    - **Exposure Time**: {exposure_time} s
    - **Polarization**: {polarization}
    - **Aberrations**: {aberrations}
    - **Measured Critical Dimension**: {cd_nm:.2f} nm

    ## Process Window Analysis
    The process window analysis shows the critical dimension variation across different focus and exposure conditions. This helps in determining the optimal process conditions and the tolerance for focus and exposure variations.

    ## 3D Resist Profile
    The 3D resist profile visualization provides insights into the resist sidewall angles and how the pattern is formed in the vertical dimension. This is crucial for understanding how the final pattern will look after development and etching.

    ## Optimization Results
    If optimization was performed, the report includes the optimized parameters and the resulting critical dimension. This shows the potential for improving the lithography process by fine-tuning the key parameters.

    ## Conclusions and Recommendations
    Based on the simulation results, here are some key observations and recommendations:
    1. The process window analysis suggests that the optimal focus range is [insert range] nm, and the optimal exposure range is [insert range] times the nominal dose.
    2. The 3D resist profile indicates that the sidewall angle is approximately [insert angle] degrees. This may affect the subsequent etching process and should be considered in the overall process design.
    3. The standing wave effect is [significant/minimal], which may require [adjusting the resist thickness/using bottom anti-reflective coating (BARC)] to minimize its impact.
    4. The impact of aberrations on the final pattern is [significant/minimal]. [If significant, recommend strategies to mitigate aberration effects.]
    5. The polarization effect is [significant/minimal]. [If significant, recommend optimizing illumination polarization.]

    ## Next Steps
    1. Validate these simulation results with experimental data.
    2. Conduct sensitivity analysis to identify the most critical parameters affecting the process.
    3. Explore advanced techniques such as source-mask optimization (SMO) or inverse lithography technology (ILT) to further improve the process window and pattern fidelity.
    4. Consider the impact of these results on the subsequent etching and pattern transfer steps in the overall semiconductor manufacturing process.
    """).strip() + "\n"

    st.markdown(report)
    st.download_button(
        label="Download Report",
        data=report,
        file_name="lithography_simulation_report.md",
        mime="text/markdown"
    )

# Add a section for comparing different parameter sets
st.header("Parameter Comparison")
# st.button is True only for the rerun caused by its click, so sliders nested
# directly under it disappear as soon as one of them is moved. Latch the click
# in session state so the comparison stays visible while the sliders are used.
if st.button("Compare Parameters"):
    st.session_state["show_parameter_comparison"] = True

if st.session_state.get("show_parameter_comparison", False):
    col1, col2 = st.columns(2)
    with col1:
        st.subheader("Current Parameters")
        intensity_current = calculate_aerial_image_2d(X, Y, na, sigma, wavelength_um, pitch_um,
                                                      duty_cycle, defocus_um, mask_pattern, polarization, aberrations)
        fig_current = px.imshow(intensity_current, title="Current Parameters")
        st.plotly_chart(fig_current)

    with col2:
        st.subheader("Modified Parameters")
        # The sliders start from the current NA / sigma values.
        na_mod = st.slider("NA (Modified)", 0.1, 1.4, na, 0.01)
        sigma_mod = st.slider("Sigma (Modified)", 0.0, 1.0, sigma, 0.01)
        intensity_mod = calculate_aerial_image_2d(X, Y, na_mod, sigma_mod, wavelength_um, pitch_um,
                                                  duty_cycle, defocus_um, mask_pattern, polarization, aberrations)
        fig_mod = px.imshow(intensity_mod, title="Modified Parameters")
        st.plotly_chart(fig_mod)

# Add a section for sensitivity analysis
st.header("Sensitivity Analysis")
if st.button("Run Sensitivity Analysis"):
    # Display label -> key in base_params. Deriving the key with label.lower()
    # fails for 'Defocus', whose key is 'defocus_um'.
    parameters = {'NA': 'na', 'Sigma': 'sigma', 'Dose': 'dose', 'Defocus': 'defocus_um'}
    # Absolute step used when the base value is 0 (e.g. zero defocus), where a
    # 5% relative step would be 0 and the finite difference would divide by 0.
    fallback_steps = {'na': 0.01, 'sigma': 0.05, 'dose': 0.05, 'defocus_um': 0.01}
    sensitivities = {}

    base_params = {'na': na, 'sigma': sigma, 'dose': dose, 'defocus_um': defocus_um}

    def _sensitivity_cd(params):
        """Return the CD (nm) simulated for one parameter set."""
        image = calculate_aerial_image_2d(X, Y, params['na'], params['sigma'], wavelength_um,
                                          pitch_um, duty_cycle, params['defocus_um'], mask_pattern,
                                          polarization, aberrations)
        # Use the same noise draw for every evaluation so the CD difference
        # reflects the parameter change rather than a different noise sample;
        # the global RNG state is restored afterwards.
        saved_rng_state = np.random.get_state()
        np.random.seed(0)
        try:
            exposed = add_shot_noise(image, params['dose'])
        finally:
            np.random.set_state(saved_rng_state)
        developed = simulate_acid_diffusion(exposed, diffusion_length_pixels)
        return measure_critical_dimension(developed >= threshold, pixel_size_nm)

    # The baseline does not depend on the perturbed parameter: compute it once.
    cd_base = _sensitivity_cd(base_params)

    for label, key in parameters.items():
        base_value = base_params[key]
        delta = base_value * 0.05 if base_value != 0 else fallback_steps[key]  # 5% change

        changed_params = dict(base_params)
        changed_params[key] = base_value + delta
        cd_changed = _sensitivity_cd(changed_params)

        sensitivities[label] = (cd_changed - cd_base) / delta

    st.write("Sensitivity Analysis Results:")
    for param, sensitivity in sensitivities.items():
        st.write(f"{param}: {sensitivity:.2f} nm/unit")

# Enhance the User Interface with Real-Time Updates
if st.sidebar.button("Run Simulation"):
    # st.experimental_rerun was deprecated in favour of st.rerun and is absent
    # from newer Streamlit releases; prefer st.rerun and fall back to the old
    # name only on versions that predate it.
    _rerun = getattr(st, "rerun", None) or st.experimental_rerun
    _rerun()